"""
urbansound8k_kd.py
------------------
Teacher–student knowledge distillation with VQCs on UrbanSound8K.

• Audio → 40-coef MFCC  → PCA(10) → angle-encoded on qubits
• Teacher  : 10-qubit EfficientSU2 (reps=2)
• Student  :  6-qubit EfficientSU2 (reps=1)
• Distill  : teacher generates pseudo-labels, student learns from them
"""

# ---------------------------------------------------------------
# 0. Imports
# ---------------------------------------------------------------
import random, warnings, os, pathlib
import numpy as np
import librosa
from datasets import load_dataset

from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split

from qiskit_aer import AerSimulator
from qiskit.circuit.library import EfficientSU2
from qiskit_machine_learning.neural_networks import SamplerQNN
from qiskit_machine_learning.algorithms import VQC
from qiskit.utils import QuantumInstance


# ---------------------------------------------------------------
# 1. Distillation “server”
# ---------------------------------------------------------------
class DistillationServer:
    """
    Encapsulates data prep, teacher training, pseudo-labelling,
    student training, and evaluation.
    """
    def __init__(
        self,
        num_teacher_qubits: int = 10,
        num_student_qubits: int = 6,
        pca_components: int = 10,
        seed: int = 123,
        max_clips: int | None = None,  # set e.g. 1500 for a quick demo
    ):
        self.seed = seed
        self.n_classes = 10  # UrbanSound8K has 10 labels
        self.num_teacher_qubits = num_teacher_qubits
        self.num_student_qubits = num_student_qubits
        self.pca_components = pca_components
        self.max_clips = max_clips
        self._rng()

        self._load_data()
        self._build_teacher()
        self._build_student()

    # 1.1  RNG helper
    def _rng(self):
        np.random.seed(self.seed)
        random.seed(self.seed)

    # 1.2  Data: download + MFCC + PCA + angle-scaling
    def _load_data(self):
        print("⇨ Loading UrbanSound8K …")
        ds = load_dataset("urbansound8k", "audio", split="train")
        if self.max_clips:
            ds = ds.shuffle(self.seed).select(range(self.max_clips))

        def mfcc_mean(audio, sr=16_000, n_mfcc=40):
            if audio.shape[-1] != sr:
                audio = librosa.resample(audio, orig_sr=audio.shape[-1], target_sr=sr)
            mfcc = librosa.feature.mfcc(y=audio.astype(np.float32), sr=sr,
                                        n_mfcc=n_mfcc, hop_length=512)
            return mfcc.mean(axis=1)

        mfccs, labels = [], []
        for sample in ds:
            mfccs.append(mfcc_mean(sample["audio"]["array"]))
            labels.append(sample["classID"])

        mfccs = np.stack(mfccs).astype(np.float32)
        labels = np.array(labels, dtype=int)

        # PCA → pca_components
        pca = PCA(n_components=self.pca_components, random_state=self.seed)
        X = pca.fit_transform(mfccs)
        # angle-encode range [0, π]
        X = np.pi * (X - X.min()) / (X.max() - X.min() + 1e-12)

        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            X, labels, test_size=0.2, stratify=labels, random_state=self.seed
        )
        print(f"   • Train size: {self.X_train.shape[0]:,}")
        print(f"   • Test  size: {self.X_test.shape[0]:,}")

    # 1.3  Feature-map circuit for angle encoding
    @staticmethod
    def _feature_map_vec(x, num_qubits):
        from qiskit import QuantumCircuit
        qc = QuantumCircuit(num_qubits)
        for i, theta in enumerate(x[:num_qubits]):
            qc.ry(theta, i)
        return qc

    # 1.4  Build teacher VQC
    def _build_teacher(self):
        ansatz = EfficientSU2(self.num_teacher_qubits, reps=2)
        self.teacher = VQC(
            feature_map=lambda x: self._feature_map_vec(x, self.num_teacher_qubits),
            ansatz=ansatz,
            optimizer="COBYLA",
            quantum_instance=AerSimulator(seed_simulator=self.seed),
            num_classes=self.n_classes,
        )

    # 1.5  Build student VQC
    def _build_student(self):
        ansatz = EfficientSU2(self.num_student_qubits, reps=1)
        self.student = VQC(
            feature_map=lambda x: self._feature_map_vec(x, self.num_student_qubits),
            ansatz=ansatz,
            optimizer="COBYLA",
            quantum_instance=AerSimulator(seed_simulator=self.seed),
            num_classes=self.n_classes,
        )

    # 1.6  Stage 1 – train teacher on true labels
    def train_teacher(self):
        print("\n⇨ Training teacher …")
        self.teacher.fit(self.X_train, self.y_train)

    # 1.7  Stage 2 – teacher generates pseudo labels
    def _pseudo_labels(self, X):
        return self.teacher.predict(X)   # hard labels for speed

    # 1.8  Stage 3 – train student on pseudo labels
    def train_student(self):
        print("\n⇨ Generating pseudo-labels …")
        pseudo = self._pseudo_labels(self.X_train)
        print("⇨ Training student on pseudo-labels …")
        self.student.fit(self.X_train, pseudo)

    # 1.9  Accuracy helper
    @staticmethod
    def _acc(model, X, y):
        return (model.predict(X) == y).mean()

    def report(self):
        print("\n=== Test accuracy ===")
        print(f"Teacher : {self._acc(self.teacher,  self.X_test, self.y_test):.3f}")
        print(f"Student : {self._acc(self.student,  self.X_test, self.y_test):.3f}")


# ---------------------------------------------------------------
# 2. Main
# ---------------------------------------------------------------
if __name__ == "__main__":
    # End-to-end run: construct the pipeline (loads data, builds both models),
    # then execute the three stages strictly in order.
    pipeline = DistillationServer(max_clips=None)   # pass e.g. 1500 for a quick run
    for stage in (pipeline.train_teacher, pipeline.train_student, pipeline.report):
        stage()
